from itertools import product
from os import path

import taichi as ti
from taichi.math import mix, clamp, reflect

# Numeric constants.
eps = 1e-4  # small tolerance value
PI = 3.1415926535  # pi to 10 decimal places
PI2 = PI * 2  # 2 * pi: one full turn in radians
deg2rad = PI / 180  # multiply an angle in degrees by this to get radians
rad2deg = 1 / deg2rad  # multiply an angle in radians by this to get degrees
Inf = float("inf")  # positive infinity
# Taichi vector/matrix type aliases. Names prefixed with "i" use int
# elements; all others use float elements.
iVec2 = ti.types.vector(2, int)
iVec4 = ti.types.vector(4, int)
Vec2 = ti.types.vector(2, float)
Vec3 = ti.types.vector(3, float)
Vec4 = ti.types.vector(4, float)
Mat2x3 = ti.types.matrix(2, 3, float)
Mat3x2 = ti.types.matrix(3, 2, float)
Mat2x2 = ti.types.matrix(2, 2, float)
Mat3x3 = ti.types.matrix(3, 3, float)
Mat4x4 = ti.types.matrix(4, 4, float)
# Identity matrices. These are single shared module-level objects; treat
# them as read-only.
Mat3x3Id = Mat3x3([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
Mat4x4Id = Mat4x4([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
# Float vector type indexed by dimension: VecN[2] is Vec2, VecN[3] is Vec3,
# VecN[4] is Vec4 (indices 0 and 1 are placeholders).
VecN = [None, None, Vec2, Vec3, Vec4]


@ti.dataclass
class Ray4:
    """Four-dimensional ray: its points are ``origin + direct * t``."""

    origin: Vec4  # starting point (x, y, z, w)
    direct: Vec4  # direction vector (not normalised by this class)

    @ti.func
    def at(self, t: float) -> Vec4:
        """Return the 4D point on the ray at parameter ``t``."""
        return self.origin + self.direct * t

    @ti.func
    def atY(self, t: float) -> float:
        """Return only the y coordinate of ``at(t)``.

        The result is a scalar, hence the ``float`` return annotation.
        """
        return self.origin.y + self.direct.y * t

    @ti.func
    def ray3(self) -> "Ray3":
        """Return the 3D version of this ray by dropping the w components.

        Intersection tests are simpler to perform on the 3D ray.
        """
        return Ray3(self.origin.xyz, self.direct.xyz)


@ti.dataclass
class Ray3:
    """Three-dimensional ray: its points are ``origin + direct * t``."""

    origin: Vec3  # starting point
    direct: Vec3  # direction vector (not normalised by this class)

    @ti.func
    def at(self, t: float) -> Vec3:
        """Return the 3D point on the ray at parameter ``t``."""
        return self.origin + self.direct * t

    @ti.func
    def atY(self, t: float) -> float:
        """Return only the y coordinate of ``at(t)``.

        The result is a scalar, hence the ``float`` return annotation.
        """
        return self.origin.y + self.direct.y * t

    @ti.func
    def ray4(self, t: float) -> Ray4:
        """Lift this ray to a 4D ``Ray4``.

        Here ``t`` is the 4th (w) component of the new origin, not a ray
        parameter. The w component of the new direction is ``-|direct|``, so
        w decreases by exactly the spatial distance travelled. Presumably
        this traces light backwards in time with c = 1; confirm against
        callers.
        """
        return Ray4(Vec4(self.origin, t), Vec4(self.direct, -self.direct.norm()))


@ti.func
def sqr(x):
    """Return ``x`` multiplied by itself (element-wise for vectors)."""
    squared = x * x
    return squared


@ti.func
def randUnitVec3() -> Vec3:
    """Random 3D unit vector, uniformly distributed over the sphere.

    The azimuth is drawn in [0, 2*pi) and the z component (the cosine of the
    polar angle) in [-1, 1). The radius of the horizontal circle at height z
    is sqrt(1 - z^2).
    """
    azimuth = ti.random() * PI2
    z = ti.random() * 2 - 1
    ring_radius = (1 - z**2) ** 0.5
    x = ti.cos(azimuth) * ring_radius
    y = ti.sin(azimuth) * ring_radius
    return Vec3(x, y, z)


@ti.func
def randUnitVec2() -> Vec2:
    """Random 2D point uniformly distributed inside the unit disk.

    Despite the name, the result is generally *not* unit length. The
    direction is uniform on the circle, and the length is sqrt(u) with u
    taken from ``ti.random()``. The square root makes the points uniform
    in area rather than bunched near the centre.
    """
    angle = ti.random() * PI2
    radius = ti.random() ** 0.5
    return radius * Vec2(ti.cos(angle), ti.sin(angle))


@ti.func
def fade(x: float) -> float:
    """Quintic fade curve 6x^5 - 15x^4 + 10x^3, as used in Perlin noise."""
    # Horner form of 6x^2 - 15x + 10; multiplying by x three times gives the quintic.
    poly = (x * 6 - 15) * x + 10
    return poly * x * x * x


def truePath(fp: str) -> str:
    """Resolve ``fp`` relative to this module's directory if it doesn't exist as given.

    If ``fp`` already exists (relative to the current working directory, or
    as an absolute path), it is returned unchanged. Otherwise it is joined
    onto the directory containing this source file. The joined result is
    not checked for existence.
    """
    if path.exists(fp):
        return fp
    module_dir = path.dirname(__file__)
    return path.join(module_dir, fp)


@ti.kernel
def conv(source: ti.template(), kernel: ti.template(), output: ti.template()):
    """2D convolution of ``source`` with ``kernel``, clamping at the edges.

    For every cell (i, j) of ``output``, this adds the sum over (m, n) of
    ``source[clamp(i - r + m), clamp(j - r + n)] * kernel[m, n]``, where
    r = kernel.shape[0] // 2. Source indices are clamped to the valid range
    of ``source``.

    Notes:
        * The kernel is not flipped (cross-correlation form). This is
          identical to convolution for symmetric kernels.
        * The kernel is treated as square, ``kernel.shape[0]`` by
          ``kernel.shape[0]``; ``kernel.shape[1]`` is never read.
        * Results are *added* into ``output`` with ``+=``. ``output`` is not
          cleared here, so zero it first if a plain convolution is wanted.
    """
    size = kernel.shape[0]
    radius = size // 2
    row_max = source.shape[0] - 1
    col_max = source.shape[1] - 1
    for i, j in output:
        for m, n in ti.ndrange(size, size):
            src_i = clamp(i - radius + m, 0, row_max)
            src_j = clamp(j - radius + n, 0, col_max)
            output[i, j] += source[src_i, src_j] * kernel[m, n]


def lorentzBoost(beta: ti.Matrix) -> ti.Matrix:
    """Build the 4x4 Lorentz boost matrix for velocity ``beta``.

    ``beta`` is expressed as a fraction of the speed of light. The time-like
    coordinate is the last row/column. With gamma = 1 / sqrt(1 - |beta|^2),
    the matrix is::

        [ I + (gamma - 1) * beta beta^T / |beta|^2   -gamma * beta ]
        [ -gamma * beta^T                             gamma        ]

    Args:
        beta: velocity 3-vector (anything ``Vec3`` accepts).

    Returns:
        A new ``Mat4x4``. For ``beta == 0`` this is a fresh identity matrix,
        not the shared module-level ``Mat4x4Id``, so callers may mutate the
        result safely.

    Raises:
        AssertionError: if ``|beta|^2 >= 1`` or it is NaN.
    """
    beta = Vec3(beta)
    beta_squared = beta.dot(beta)
    if not beta_squared:
        # Return a fresh identity. Returning the shared Mat4x4Id would let an
        # in-place edit by a caller corrupt the module-level constant.
        return Mat4x4([[float(r == c) for c in range(4)] for r in range(4)])
    # Use an explicit check rather than `assert` so it also runs under
    # `python -O`. The exception type stays AssertionError for compatibility,
    # and `not ... < 1` also rejects NaN, as the original assert did.
    if not beta_squared < 1:
        raise AssertionError("can't run faster than light!")
    gamma = (1 - beta_squared) ** -0.5
    lambda_0j = -gamma * beta  # space-time mixing terms (last column / row)
    # Spatial block correction: (gamma - 1) * beta beta^T / |beta|^2
    lambda_ij = (gamma - 1) * beta.outer_product(beta) / beta_squared
    mat = Mat4x4(
        [
            [1, 0, 0, lambda_0j.x],
            [0, 1, 0, lambda_0j.y],
            [0, 0, 1, lambda_0j.z],
            [*lambda_0j, gamma],
        ]
    )
    for i, j in product(range(3), range(3)):
        mat[i, j] += lambda_ij[i, j]
    return mat


@ti.func
def sampleCosineHemisphere() -> Vec3:
    """Cosine-weighted random direction on the unit hemisphere around +z.

    A point is drawn uniformly on the unit disk (radius sqrt(u1), angle
    2*pi*u2) and lifted straight up onto the hemisphere (Malley's method).
    The returned unit vector therefore has z >= 0 and a density proportional
    to its z component.
    """
    disk_r = ti.sqrt(ti.random())
    angle = PI2 * ti.random()
    height = ti.sqrt(1 - disk_r * disk_r)
    return Vec3(disk_r * ti.cos(angle), disk_r * ti.sin(angle), height)


@ti.func
def toNormalHemisphere(dir: Vec3, Normal: Vec3) -> Vec3:
    """Re-express a local-frame direction in the frame whose z axis is ``Normal``.

    ``dir`` is read as coordinates in a local basis (tangent, bitangent,
    Normal) and converted to ``dir.x * tangent + dir.y * bitangent +
    dir.z * Normal``. A +z-hemisphere sample therefore lands in the
    hemisphere around ``Normal``. ``Normal`` is not normalised here;
    presumably callers pass a unit vector (verify).
    """
    # Choose a vector perpendicular to Normal. Branching on which of |x| and
    # |z| is larger keeps the chosen vector away from zero length.
    bitangent = Vec3(0, 0, 0)
    if abs(Normal.x) > abs(Normal.z):
        bitangent = Vec3(-Normal.y, Normal.x, 0)
    else:
        bitangent = Vec3(0, -Normal.z, Normal.y)
    bitangent = bitangent.normalized()
    tangent = bitangent.cross(Normal).normalized()
    return tangent * dir.x + bitangent * dir.y + Normal * dir.z


@ti.func
def toSpherical(v: Vec3) -> Vec2:
    """Map a 3D direction to 2D spherical texture coordinates, roughly in [0, 1].

    ``v`` is normalised first. Component 0 comes from the elevation
    asin(y), which lies in [-pi/2, pi/2]. Component 1 comes from the azimuth
    atan2(z, x), which lies in [-pi, pi]. They are scaled by 0.3183 (~1/pi)
    and 0.1591 (~1/(2*pi)) respectively, then shifted by 0.5.
    """
    d = v.normalized()
    elevation = ti.asin(d.y)
    azimuth = ti.atan2(d.z, d.x)
    return Vec2(elevation, azimuth) * Vec2(0.3183, 0.1591) + 0.5


def IdGenerator():
    """Yield consecutive integer IDs 0, 1, 2, ... without end.

    Each call creates an independent generator. Call ``next()`` on it to
    obtain the next unused ID.
    """
    next_id = 0
    while True:
        yield next_id
        next_id += 1


@ti.func
def superSampling(n: int, t: int) -> Vec2:
    """Jittered (stratified) sub-pixel offset for sample ``t`` of an n x n grid.

    Sample index ``t`` selects cell (t // n % n, t % n) of an n x n grid over
    the unit square. A random jitter from ``ti.random()`` is added within
    that cell, so for t >= 0 the result lies in [0, 1)^2. Indices t >= n*n
    wrap around.

    Args:
        n: grid resolution per axis (n*n strata); must be non-zero.
        t: sample index.
    """
    return Vec2(
        # Floor division gives the integer row index. True division (`/`)
        # would give a fractional offset and push x past 1 for n >= 2.
        (ti.random() + t // n % n) / n,
        (ti.random() + t % n) / n,
    )
